"""
The Quora Duplicate Questions dataset contains question pairs from Quora (www.quora.com)
along with a label indicating whether the two questions are duplicates, i.e., have the same intent.

Example of a duplicate pair:
How do I enhance my English?  AND  How can I become good at English?

Example of a non-duplicate pair:
How are roads named?   AND    How are airport runways named?

More details and the original Quora dataset can be found here:
https://www.quora.com/q/quoradata/First-Quora-Dataset-Release-Question-Pairs
Dataset: http://qim.fs.quoracdn.net/quora_duplicate_questions.tsv

You do not need to run this script. You can download all files from here:
https://sbert.net/datasets/quora-duplicate-questions.zip

This script does the following:
1) After reading quora_duplicate_questions.tsv, as provided by Quora, we add the transitive closure: If questions (A, B) are duplicates and (B, C) are duplicates, then (A, C) must also be duplicates. We add these missing links.

2) Next, we split the sentences into train, dev, and test with a ratio of about 85% / 5% / 10%. In contrast to most other Quora data splits, such as the split provided by GLUE, we ensure that the three sets are overlap-free, i.e., no sentence from dev / test appears in the train set. To obtain three distinct sets, we pick a sentence and then assign all of its duplicate sentences to the same set.

3) After distributing the sentences to the three splits, we create files to facilitate 3 different tasks:
    3.1) Classification - Given two sentences, are they duplicates? This is identical to the original Quora task and the task in GLUE, but with the big difference that sentences in dev / test have not been seen in train.
    3.2) Duplicate Question Mining - Given a large set of questions, identify all duplicates. The dev set consists of about 50k questions, the test set of about 100k questions.
    3.3) Information Retrieval - Given a question as query, find its duplicates in a large corpus (~350k questions).


The output consists of the following files:

quora_duplicate_questions.tsv - Original file provided by Quora (https://www.quora.com/q/quoradata/First-Quora-Dataset-Release-Question-Pairs)

classification/
    train/dev/test_pairs.tsv - Distinct sets of question pairs with a duplicate / non-duplicate label. These splits can be used for sentence-pair classification tasks

duplicate-mining/ - Given a large set of questions, find all duplicates.
    train/dev/test_corpus.tsv - Large set of questions
    train/dev/test_duplicates.tsv - All duplicate question pairs within the respective corpus file

information-retrieval/  - Given a large corpus of questions, find the duplicates for a given query
    corpus.tsv - The corpus shared by train/dev/test. It contains all questions in the corpus
    dev/test-queries.tsv - Queries together with the QIDs of their duplicate questions in the corpus

"""

import csv
import os
import random
from collections import defaultdict

from sentence_transformers.util import http_get

random.seed(42)

# Get raw file
source_file = "quora-IR-dataset/quora_duplicate_questions.tsv"
os.makedirs("quora-IR-dataset", exist_ok=True)
os.makedirs("quora-IR-dataset/graph", exist_ok=True)
os.makedirs("quora-IR-dataset/information-retrieval", exist_ok=True)
os.makedirs("quora-IR-dataset/classification", exist_ok=True)
os.makedirs("quora-IR-dataset/duplicate-mining", exist_ok=True)

if not os.path.exists(source_file):
    print("Download file to", source_file)
    http_get("http://qim.fs.quoracdn.net/quora_duplicate_questions.tsv", source_file)

# Read pairwise file
sentences = {}
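# duplicates[qid1][qid2] is True iff the questions with these ids are marked as duplicates (kept symmetric)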
duplicates = defaultdict(lambda: defaultdict(bool))
rows = []
with open(source_file, encoding="utf8") as fIn:
    reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_MINIMAL)
    for row in reader:
        id1 = row["qid1"]
        id2 = row["qid2"]
        question1 = row["question1"].replace("\r", "").replace("\n", " ").replace("\t", " ")
        question2 = row["question2"].replace("\r", "").replace("\n", " ").replace("\t", " ")
        is_duplicate = row["is_duplicate"]

        if question1 == "" or question2 == "":
            continue

        sentences[id1] = question1
        sentences[id2] = question2

        rows.append(
            {"qid1": id1, "qid2": id2, "question1": question1, "question2": question2, "is_duplicate": is_duplicate}
        )

        if is_duplicate == "1":
            duplicates[id1][id2] = True
            duplicates[id2][id1] = True


# Search for (nearly) exact duplicates
# The original Quora duplicate questions dataset is only incompletely annotated,
# i.e., there are several duplicate question pairs which are not marked as duplicates.
# These missing annotations can make it difficult to compare approaches.
# Here we use a simple approach that searches for nearly identical questions that differ only in, for example, a stop word.
# We also mark these question pairs as duplicates to increase the annotation coverage
stopwords = set(
    [
        "a",
        "about",
        "above",
        "after",
        "again",
        "against",
        "ain",
        "all",
        "am",
        "an",
        "and",
        "any",
        "are",
        "aren",
        "aren't",
        "as",
        "at",
        "be",
        "because",
        "been",
        "before",
        "being",
        "below",
        "between",
        "both",
        "but",
        "by",
        "can",
        "couldn",  # codespell:ignore couldn
        "couldn't",
        "d",
        "did",
        "didn",
        "didn't",
        "do",
        "does",
        "doesn",
        "doesn't",
        "doing",
        "don",
        "don't",
        "down",
        "during",
        "each",
        "few",
        "for",
        "from",
        "further",
        "had",
        "hadn",
        "hadn't",
        "has",
        "hasn",
        "hasn't",
        "have",
        "haven",
        "haven't",
        "having",
        "he",
        "her",
        "here",
        "hers",
        "herself",
        "him",
        "himself",
        "his",
        "i",
        "if",
        "in",
        "into",
        "is",
        "isn",
        "isn't",
        "it's",
        "its",
        "itself",
        "just",
        "ll",
        "m",
        "ma",
        "me",
        "mightn",
        "mightn't",
        "more",
        "most",
        "mustn",
        "mustn't",
        "my",
        "myself",
        "needn",
        "needn't",
        "no",
        "nor",
        "not",
        "now",
        "o",
        "of",
        "off",
        "on",
        "once",
        "only",
        "or",
        "other",
        "our",
        "ours",
        "ourselves",
        "out",
        "over",
        "own",
        "re",
        "s",
        "same",
        "shan",
        "shan't",
        "she",
        "she's",
        "should",
        "should've",
        "shouldn",
        "shouldn't",
        "so",
        "some",
        "such",
        "t",
        "than",
        "that",
        "that'll",
        "the",
        "their",
        "theirs",
        "them",
        "themselves",
        "then",
        "there",
        "these",
        "they",
        "this",
        "those",
        "through",
        "to",
        "too",
        "under",
        "until",
        "up",
        "ve",
        "very",
        "was",
        "wasn",
        "wasn't",
        "we",
        "were",
        "weren",
        "weren't",
        "which",
        "while",
        "will",
        "with",
        "won",
        "won't",
        "wouldn",
        "wouldn't",
        "y",
        "you",
        "you'd",
        "you'll",
        "you're",
        "you've",
        "your",
        "yours",
        "yourself",
        "yourselves",
    ]
)

num_new_duplicates = 0
sentences_norm = {}

for id, sent in sentences.items():
    sent_norm = sent.lower()

    # Replace some common paraphrases
    sent_norm = sent_norm.replace("how do you", "how do i").replace("how do we", "how do i")
    sent_norm = (
        sent_norm.replace("how can we", "how can i")
        .replace("how can you", "how can i")
        .replace("how can i", "how do i")
    )
    sent_norm = sent_norm.replace("really true", "true")
    sent_norm = sent_norm.replace("what are the importance", "what is the importance")
    sent_norm = sent_norm.replace("what was", "what is")
    sent_norm = sent_norm.replace("so many", "many")
    sent_norm = sent_norm.replace("would it take", "will it take")

    # Remove any punctuation characters
    for c in [",", "!", ".", "?", "'", '"', ":", ";", "[", "]", "{", "}", "<", ">"]:
        sent_norm = sent_norm.replace(c, " ")

    # Remove stop words and concatenate the remaining tokens into a normalization key
    tokens = sent_norm.split()
    tokens = [token for token in tokens if token not in stopwords]
    sent_norm = "".join(tokens)

    if sent_norm in sentences_norm:
        if not duplicates[id][sentences_norm[sent_norm]]:
            num_new_duplicates += 1

        duplicates[id][sentences_norm[sent_norm]] = True
        duplicates[sentences_norm[sent_norm]][id] = True
    else:
        sentences_norm[sent_norm] = id


print("(Nearly) exact duplicates found:", num_new_duplicates)


# Add transitive closure (if a,b and b,c duplicates => a,c are duplicates)
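# We repeat the pass over all sentences until no new pairs are added, i.e., until a fixpoint is reached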
new_entries = True
while new_entries:
    print("Add transitive closure")
    new_entries = False
    for a in sentences:
        for b in list(duplicates[a]):
            for c in list(duplicates[b]):
                if a != c and not duplicates[a][c]:
                    new_entries = True
                    duplicates[a][c] = True
                    duplicates[c][a] = True


# Distribute the rows to the train/dev/test splits
# Ensure that the three splits contain disjoint sets of sentences
is_assigned = set()
random.shuffle(rows)

train_ids = set()
dev_ids = set()
test_ids = set()

counter = 0
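# counter only advances for pairs where neither sentence is assigned yet, so the split ratio below applies to newly started clusters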
for row in rows:
    if row["qid1"] in is_assigned and row["qid2"] in is_assigned:
        continue
    elif row["qid1"] in is_assigned or row["qid2"] in is_assigned:
        if row["qid2"] in is_assigned:  # Ensure that qid1 is assigned and qid2 not yet
            row["qid1"], row["qid2"] = row["qid2"], row["qid1"]

        # Move qid2 to the same split as qid1
        target_set = train_ids
        if row["qid1"] in dev_ids:
            target_set = dev_ids
        elif row["qid1"] in test_ids:
            target_set = test_ids

    else:
        # Distribution about 85%/5%/10%
        target_set = train_ids
        if counter % 10 == 0:
            target_set = dev_ids
        elif counter % 10 == 1 or counter % 10 == 2:
            target_set = test_ids
        counter += 1

    # Add both sentences and all of their duplicates to the selected set
    target_set.add(row["qid1"])
    is_assigned.add(row["qid1"])

    target_set.add(row["qid2"])
    is_assigned.add(row["qid2"])

    for b in list(duplicates[row["qid1"]]) + list(duplicates[row["qid2"]]):
        target_set.add(b)
        is_assigned.add(b)


# Assert that the three splits are pairwise disjoint
assert len(train_ids.intersection(dev_ids)) == 0
assert len(train_ids.intersection(test_ids)) == 0
assert len(test_ids.intersection(dev_ids)) == 0


print("\nTrain sentences:", len(train_ids))
print("Dev sentences:", len(dev_ids))
print("Test sentences:", len(test_ids))


# Extract the ids of duplicate question pairs for train/dev/test
def get_duplicate_set(ids_set):
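    """Return all duplicate pairs (as sorted qid tuples) that involve a question from ids_set"""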
    dups_set = set()
    for a in ids_set:
        for b in duplicates[a]:
            ids = sorted([a, b])
            dups_set.add(tuple(ids))
    return dups_set


train_duplicates = get_duplicate_set(train_ids)
dev_duplicates = get_duplicate_set(dev_ids)
test_duplicates = get_duplicate_set(test_ids)


print("\nTrain duplicates", len(train_duplicates))
print("Dev duplicates", len(dev_duplicates))
print("Test duplicates", len(test_duplicates))

############### Write general files about the duplicate questions graph ############
with open("quora-IR-dataset/graph/sentences.tsv", "w", encoding="utf8") as fOut:
    fOut.write("qid\tquestion\n")
    for id, question in sentences.items():
        fOut.write(f"{id}\t{question}\n")

duplicates_list = set()
for a in duplicates:
    for b in duplicates[a]:
        duplicates_list.add(tuple(sorted([int(a), int(b)])))


duplicates_list = sorted(duplicates_list)


print("\nWrite duplicate graph in pairwise format")
with open("quora-IR-dataset/graph/duplicates-graph-pairwise.tsv", "w", encoding="utf8") as fOut:
    fOut.write("qid1\tqid2\n")
    for a, b in duplicates_list:
        fOut.write(f"{a}\t{b}\n")


print("Write duplicate graph in list format")
with open("quora-IR-dataset/graph/duplicates-graph-list.tsv", "w", encoding="utf8") as fOut:
    fOut.write("qid1\tqid2\n")
    for a in sorted(duplicates.keys(), key=lambda x: int(x)):
        if len(duplicates[a]) > 0:
            fOut.write("{}\t{}\n".format(a, ",".join(sorted(duplicates[a]))))

print("Write duplicate graph in connected subgraph format")
with open("quora-IR-dataset/graph/duplicates-graph-connected-nodes.tsv", "w", encoding="utf8") as fOut:
    written_qids = set()
    fOut.write("qids\n")
    for a in sorted(duplicates.keys(), key=lambda x: int(x)):
        if a not in written_qids:
            ids = set()
            ids.add(a)

            for b in duplicates[a]:
                ids.add(b)

            fOut.write("{}\n".format(",".join(sorted(ids, key=lambda x: int(x)))))
            for id in ids:
                written_qids.add(id)


def write_qids(name, ids_list):
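    """Write the question ids of one split to graph/<name>-questions.tsv"""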
    with open("quora-IR-dataset/graph/" + name + "-questions.tsv", "w", encoding="utf8") as fOut:
        fOut.write("qid\n")
        fOut.write("\n".join(sorted(ids_list, key=lambda x: int(x))))


write_qids("train", train_ids)
write_qids("dev", dev_ids)
write_qids("test", test_ids)


####### Output for duplicate mining #######
def write_mining_files(name, ids, dups):
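    """Write the corpus and the gold duplicate pairs of one split for the duplicate-mining task"""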
    with open("quora-IR-dataset/duplicate-mining/" + name + "_corpus.tsv", "w", encoding="utf8") as fOut:
        fOut.write("qid\tquestion\n")
        for id in ids:
            fOut.write(f"{id}\t{sentences[id]}\n")

    with open("quora-IR-dataset/duplicate-mining/" + name + "_duplicates.tsv", "w", encoding="utf8") as fOut:
        fOut.write("qid1\tqid2\n")
        for a, b in dups:
            fOut.write(f"{a}\t{b}\n")


write_mining_files("train", train_ids, train_duplicates)
write_mining_files("dev", dev_ids, dev_duplicates)
write_mining_files("test", test_ids, test_duplicates)


###### Classification dataset #####
with (
    open("quora-IR-dataset/classification/train_pairs.tsv", "w", encoding="utf8") as fOutTrain,
    open("quora-IR-dataset/classification/dev_pairs.tsv", "w", encoding="utf8") as fOutDev,
    open("quora-IR-dataset/classification/test_pairs.tsv", "w", encoding="utf8") as fOutTest,
):
    fOutTrain.write("\t".join(["qid1", "qid2", "question1", "question2", "is_duplicate"]) + "\n")
    fOutDev.write("\t".join(["qid1", "qid2", "question1", "question2", "is_duplicate"]) + "\n")
    fOutTest.write("\t".join(["qid1", "qid2", "question1", "question2", "is_duplicate"]) + "\n")

    for row in rows:
        id1 = row["qid1"]
        id2 = row["qid2"]

        target = None
        if id1 in train_ids and id2 in train_ids:
            target = fOutTrain
        elif id1 in dev_ids and id2 in dev_ids:
            target = fOutDev
        elif id1 in test_ids and id2 in test_ids:
            target = fOutTest

        if target is not None:
            target.write("\t".join([row["qid1"], row["qid2"], sentences[id1], sentences[id2], row["is_duplicate"]]))
            target.write("\n")


####### Write files for Information Retrieval #####
num_dev_queries = 5000
num_test_queries = 10000

corpus_ids = train_ids.copy()
dev_queries = set()
test_queries = set()

# Create dev queries
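# Select up to num_dev_queries dev questions that have at least one duplicate as queries;
# all remaining dev questions, including the duplicates of the selected queries, go into the shared corpus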
rnd_dev_ids = sorted(dev_ids)
random.shuffle(rnd_dev_ids)

for a in rnd_dev_ids:
    if a not in corpus_ids:
        if len(dev_queries) < num_dev_queries and len(duplicates[a]) > 0:
            dev_queries.add(a)
        else:
            corpus_ids.add(a)

        for b in duplicates[a]:
            if b not in dev_queries:
                corpus_ids.add(b)

# Create test queries
rnd_test_ids = sorted(test_ids)
random.shuffle(rnd_test_ids)

for a in rnd_test_ids:
    if a not in corpus_ids:
        if len(test_queries) < num_test_queries and len(duplicates[a]) > 0:
            test_queries.add(a)
        else:
            corpus_ids.add(a)

        for b in duplicates[a]:
            if b not in test_queries:
                corpus_ids.add(b)

# Write output for information-retrieval
print("\nInformation Retrieval Setup")
print("Corpus size:", len(corpus_ids))
print("Dev queries:", len(dev_queries))
print("Test queries:", len(test_queries))

with open("quora-IR-dataset/information-retrieval/corpus.tsv", "w", encoding="utf8") as fOut:
    fOut.write("qid\tquestion\n")
    for id in sorted(corpus_ids, key=lambda id: int(id)):
        fOut.write(f"{id}\t{sentences[id]}\n")

with open("quora-IR-dataset/information-retrieval/dev-queries.tsv", "w", encoding="utf8") as fOut:
    fOut.write("qid\tquestion\tduplicate_qids\n")
    for id in sorted(dev_queries, key=lambda id: int(id)):
        fOut.write("{}\t{}\t{}\n".format(id, sentences[id], ",".join(duplicates[id])))

with open("quora-IR-dataset/information-retrieval/test-queries.tsv", "w", encoding="utf8") as fOut:
    fOut.write("qid\tquestion\tduplicate_qids\n")
    for id in sorted(test_queries, key=lambda id: int(id)):
        fOut.write("{}\t{}\t{}\n".format(id, sentences[id], ",".join(duplicates[id])))


print("--DONE--")
